
fix: attach CP attention-mask hooks for dense (non-TE) context parallelism#1470

Open
hemildesai wants to merge 12 commits into main from hemil/cp-dense-fixes

Conversation

@hemildesai
Contributor

Summary

  • Add _attach_context_parallel_hooks to register forward pre-hooks on self_attn modules that strip attention_mask and set is_causal=True, fixing shape mismatches when dense (non-TE) context parallelism shards Q/K/V as DTensors
  • Call the hooks in TrainFinetuneRecipeForNextTokenPrediction when cp_size > 1 and TE attention is not used
  • Add unit tests for the new hook function and the attention_mask removal in make_cp_batch_and_ctx
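The hook described above can be sketched as follows. This is a hypothetical illustration, not the PR's actual implementation: the function name `attach_context_parallel_hooks` and the `self_attn` module naming mirror the description, but the module-matching logic and kwarg names are assumptions.

```python
import torch.nn as nn


def attach_context_parallel_hooks(model: nn.Module) -> None:
    """Register forward pre-hooks that strip attention_mask and force is_causal=True."""

    def _strip_mask(module, args, kwargs):
        # The 4D attention_mask no longer matches the CP-sharded Q/K/V shapes,
        # so drop it and let SDPA build the causal mask internally.
        kwargs.pop("attention_mask", None)
        kwargs["is_causal"] = True
        return args, kwargs

    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            module.register_forward_pre_hook(_strip_mask, with_kwargs=True)
```

`register_forward_pre_hook(..., with_kwargs=True)` lets the hook rewrite keyword arguments before the attention module's `forward` runs, so no model code needs to change.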

Test plan

  • Unit tests pass: pytest tests/unit_tests/distributed/test_cp_utils.py (12 tests, all passing)
  • Manual validation with dense CP training (non-TE backend)

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot bot commented Mar 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai
Contributor Author

/ok to test e3fb07e

@hemildesai
Contributor Author

/ok to test 5165660

@hemildesai
Contributor Author

/ok to test 274666d

@akoumpa
Contributor

akoumpa commented Mar 7, 2026

/ok to test 9098ba2

@hemildesai
Contributor Author

/ok to test 32e55a4

hemildesai and others added 9 commits March 8, 2026 17:54
…elism

Strip the 4D attention_mask from the batch and register forward pre-hooks
on self_attn modules to set is_causal=True, so that SDPA handles causal
masking internally when using dense context parallelism without TE.

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Replace functools.partial(F.scaled_dot_product_attention, ...) with a
closure that resolves F.scaled_dot_product_attention at call time. This
ensures CP's runtime monkey-patch of the function is picked up by all
custom models instead of being bypassed by the early-bound reference.

Also make _attach_context_parallel_hooks public (renamed to
attach_context_parallel_hooks).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
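A minimal illustration of the early- vs late-binding issue this commit fixes: `functools.partial` captures the current function object at definition time, so a later runtime monkey-patch of `F.scaled_dot_product_attention` is never seen, whereas a closure that looks the attribute up on each call picks the patch up. The function names here are generic, not the PR's.

```python
import functools

import torch.nn.functional as F

# Early-bound: freezes whatever F.scaled_dot_product_attention points to right now.
early_bound = functools.partial(F.scaled_dot_product_attention)


# Late-bound: resolves the attribute at call time, so CP's runtime patch applies.
def late_bound(*args, **kwargs):
    return F.scaled_dot_product_attention(*args, **kwargs)
```

Any code path that stores `early_bound` keeps calling the original kernel even after context parallelism swaps the function out, which is exactly the bypass the commit message describes.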
…ends

Extract SDPA backend selection into a resolve_sdpa_method() helper that
accepts string names from YAML config (e.g. ["flash_attention",
"efficient_attention"]) and converts them to SDPBackend enum members.
When no explicit config is provided, auto-selects based on CP and
activation checkpointing constraints.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Replace the assert that required all attention modules to be TE
DotProductAttention with a continue, so dense (SDPA) attention
modules are gracefully skipped. This allows MoE models to use
context parallelism with non-TE attention backends.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
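A toy sketch of the control-flow change: instead of asserting that every attention module is a TE `DotProductAttention`, non-TE (SDPA) modules are skipped with `continue`. The function name and the name-based type check are assumptions for illustration.

```python
def register_te_cp_groups(attn_modules):
    """Collect TE DotProductAttention modules, skipping dense SDPA ones."""
    registered = []
    for m in attn_modules:
        if type(m).__name__ != "DotProductAttention":
            continue  # dense SDPA attention: CP is handled via mask hooks instead
        registered.append(m)  # real code would also set the TE CP group here (assumed)
    return registered
```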
Move the resolve_sdpa_method helper from train_ft.py to
_transformers/model_init.py per review feedback. The config
resolution (reading sdpa_method from YAML and passing it to
build_model) remains in train_ft.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Move the attach_context_parallel_hooks call from train_ft.py into
apply_model_infrastructure in infrastructure.py, which already has
access to the device mesh. Add _uses_te_attention helper that inspects
the model's self_attn.attn_module instances to determine if TE
DotProductAttention is used, replacing the config-based check.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
When the model is an AutoPipeline, iterate over model.parts to inspect
self_attn modules instead of only the pipeline wrapper itself.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Move the resolve_sdpa_method call from build_model in train_ft.py into
from_pretrained/from_config in auto_model.py, where device_mesh and
activation_checkpointing are available. train_ft.py now passes raw
string values from YAML as sdpa_method without pre-resolving.

resolve_sdpa_method now accepts both string names and SDPBackend enum
values, making the API flexible for both YAML config and programmatic
use.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai
Contributor Author

/ok to test d7db759

SDPA backend patching only runs inside NeMoAutoModel._build_model, so
sdpa_method has no effect for custom model builders. Log a warning to
avoid silently dropping the user's config.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai
Contributor Author

/ok to test 262b857

The SDPA attn_func changed from functools.partial to a closure,
so .keywords no longer exists. Mock F.scaled_dot_product_attention
and inspect call kwargs instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai
Contributor Author

/ok to test 3ff488e

@adil-a
Collaborator

adil-a commented Mar 9, 2026

/ok to test aa685e0


3 participants